import torch

# Problem sizes and integration-grid settings
n0, n1 = 20, 20
M, N = 3, 3
num_points = 100
chunk_size = 20

# Random input tensor of shape [n0, n0, n1]
tensor1 = torch.randn((n0, n0, n1))

# 1-D view over the same storage as tensor1
tensor2 = tensor1.view(tensor1.numel())
print(tensor2.shape)

# Define the functions
def c_old(u_values, M=3, N=3, num_points=100, chunk_size=20):
    """Approximate the integral of ``-t`` from 0 to ``u`` for every entry of ``u_values``.

    The integrand is sampled on ``torch.linspace(-M, N, num_points)`` and summed
    over the grid points lying in ``(0, u]`` (for ``u >= 0``) or ``(u, 0]`` (for
    ``u < 0``, with the sign flipped). Entries are processed ``chunk_size`` at a
    time so that at most a ``chunk_size x num_points`` intermediate is alive.

    Args:
        u_values: Tensor of upper integration limits. Any shape is accepted;
            the computation is element-wise.
        M: The grid starts at ``-M``.
        N: The grid ends at ``N``.
        num_points: Number of grid samples.
        chunk_size: Number of entries evaluated per iteration; must be positive.

    Returns:
        Tensor with the shape of ``u_values``. Floating-point inputs keep their
        dtype; other dtypes yield the default floating-point dtype.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    device = u_values.device
    # Build the grid in the input's float dtype so the result can be stored
    # without a dtype mismatch (e.g. for float64 inputs).
    if u_values.is_floating_point():
        dtype = u_values.dtype
    else:
        dtype = torch.get_default_dtype()

    t_values = torch.linspace(-M, N, num_points, dtype=dtype, device=device).unsqueeze(0)
    # Integrand b / a with b = -t / 2 and a = 1 / 2, which is exactly -t.
    integrand = -t_values
    # NOTE: this weight is (M + N) / num_points, not the linspace spacing
    # (M + N) / (num_points - 1); kept unchanged to preserve existing results.
    dt = (M + N) / num_points

    flat_u = u_values.reshape(-1)
    result = torch.zeros(flat_u.shape, dtype=dtype, device=device)

    for start in range(0, flat_u.size(0), chunk_size):
        stop = start + chunk_size  # slicing clamps at the end of the tensor
        u_chunk = flat_u[start:stop].unsqueeze(1)

        in_positive_range = (t_values > 0) & (t_values <= u_chunk)
        in_negative_range = (t_values <= 0) & (t_values > u_chunk)

        positive_part = (integrand * in_positive_range).sum(dim=1) * dt
        negative_part = -(integrand * in_negative_range).sum(dim=1) * dt

        result[start:stop] = torch.where(
            u_chunk.squeeze(1) >= 0, positive_part, negative_part
        )

    return result.reshape(u_values.shape)

def c_new(u_values, M=3, N=3, num_points=100):
    """Approximate the integral of ``-t`` from 0 to ``u`` for every entry of ``u_values``.

    Fully vectorised: the integrand is sampled on
    ``torch.linspace(-M, N, num_points)`` and, for each entry ``u``, summed over
    the grid points lying in ``(0, u]`` (for ``u >= 0``) or ``(u, 0]`` (for
    ``u < 0``, with the sign flipped). An intermediate of
    ``u_values.numel() x num_points`` elements is materialised.

    Args:
        u_values: Tensor of upper integration limits, of any shape (including
            0-dim, empty and non-contiguous tensors).
        M: The grid starts at ``-M``.
        N: The grid ends at ``N``.
        num_points: Number of grid samples.

    Returns:
        Tensor with the shape of ``u_values``. Floating-point inputs keep their
        dtype; other dtypes yield the default floating-point dtype.
    """
    device = u_values.device
    # Keep the input's float dtype instead of silently computing in float32.
    if u_values.is_floating_point():
        dtype = u_values.dtype
    else:
        dtype = torch.get_default_dtype()

    t_values = torch.linspace(-M, N, num_points, dtype=dtype, device=device)
    # Integrand b / a with b = -t / 2 and a = 1 / 2, which is exactly -t.
    integrand = -t_values
    # NOTE: this weight is (M + N) / num_points, not the linspace spacing
    # (M + N) / (num_points - 1); kept unchanged to preserve existing results.
    dt = (M + N) / num_points

    # The operation is element-wise, so flatten everything into one column that
    # broadcasts against the grid. reshape (unlike view) also accepts
    # non-contiguous inputs.
    u_column = u_values.reshape(-1).unsqueeze(1)

    in_positive_range = (t_values > 0) & (t_values <= u_column)
    in_negative_range = (t_values <= 0) & (t_values > u_column)

    positive_part = (integrand * in_positive_range).sum(dim=-1) * dt
    negative_part = -(integrand * in_negative_range).sum(dim=-1) * dt

    result = torch.where(u_column.squeeze(1) >= 0, positive_part, negative_part)
    return result.reshape(u_values.shape)

# Calculate results, forwarding the settings declared at the top of the script
# so that editing them actually changes the computation
result_old = c_old(tensor2, M=M, N=N, num_points=num_points, chunk_size=chunk_size)
result_new = c_new(tensor1, M=M, N=N, num_points=num_points)
# reshape (rather than view) flattens regardless of memory layout
result_new_flat = result_new.reshape(-1)

# Check if results are the same (within torch.allclose default tolerances)
are_equal = torch.allclose(result_old, result_new_flat)
print(f"Are the results equal? {are_equal}")
print(f"Old shape: {result_old.shape}")
print(f"new shape: {result_new.shape}")
